"""
Phase 5: PPML (Poisson Pseudo-Maximum Likelihood) robustness for gravity models.

Estimates PPML versions of Models 2b and 2c from the gravity bilateral paper.
PPML uses the LEVEL of bilateral portfolio positions as the dependent variable
(not log-transformed), which properly handles zeros and is consistent under
heteroskedasticity (Santos Silva & Tenreyro, 2006).

Output: gravity_bilateral/output/tables/ppml_results.csv
"""

import pandas as pd
import numpy as np
import statsmodels.api as sm
from statsmodels.genmod.generalized_linear_model import GLM
import warnings
import time

# Silence statsmodels' ConvergenceWarning for the rest of the script.
warnings.filterwarnings(
    action="ignore",
    category=sm.tools.sm_exceptions.ConvergenceWarning,
)

# ---------- paths ----------
# Input panel, previously saved OLS results (read for the comparison table),
# and the destination of the PPML results table.
DATA = "/mnt/c/demographics_capital_flows/gravity_bilateral/data/processed/bilateral_panel.csv"
OLS_RESULTS = "/mnt/c/demographics_capital_flows/gravity_bilateral/output/tables/gravity_results.csv"
OUT = "/mnt/c/demographics_capital_flows/gravity_bilateral/output/tables/ppml_results.csv"

# ---------- load data ----------
print("Loading bilateral panel...")
df = pd.read_csv(DATA)
print(f"  Raw panel: {len(df):,} rows, {df.shape[1]} columns")

# ---------- dependent variable ----------
# The dependent variable is portfolio_total in levels. Rows with a missing
# value are removed first, then rows with a negative value (161 obs according
# to the original author's note). Zeros are kept on purpose.
# The level is expressed in millions of USD, intended to help Poisson
# convergence.
df = df.dropna(subset=["portfolio_total"]).copy()
df = df.loc[df["portfolio_total"].ge(0)].copy()
df["y_millions"] = df["portfolio_total"].div(1e6)

_y_level = df["y_millions"]
_zero_count = (_y_level == 0).sum()
_positive_count = (_y_level > 0).sum()

print(f"  After dropping NaN/negative portfolio_total: {len(df):,}")
print(f"  Zeros: {_zero_count:,}")
print(f"  Positive: {_positive_count:,}")
print(f"  Dep var (millions USD) — mean: {_y_level.mean():.1f}, "
      f"median: {_y_level.median():.1f}, max: {_y_level.max():.0f}")

# ---------- year dummies ----------
# One float indicator per year, named "yr_<year>"; the first year is dropped
# and acts as the reference category.
df["year"] = df["year"].astype(int)
year_dummies = pd.get_dummies(df["year"], prefix="yr", drop_first=True, dtype=float)
yr_cols = year_dummies.columns.tolist()

# ---------- model specifications ----------
# Regressor groups; each specification below is a concatenation of groups.
base_vars = ["log_dist", "contiguity", "common_lang_official", "colonial_ties", "log_gdp_product"]
demog_vars = ["dZ_1", "dZ_2", "dZ_3"]
kaopen_vars = ["kaopen_j", "dZ_1_x_kaopen_j", "dZ_2_x_kaopen_j", "dZ_3_x_kaopen_j"]

model_specs = {
    "PPML 2b: Gravity + Demographics": [*base_vars, *demog_vars],
    "PPML 2c: Gravity + Demographics + KAOPEN": [*base_vars, *demog_vars, *kaopen_vars],
}

# ---------- estimate ----------
# Every specification in model_specs is fitted as a Poisson GLM on the level
# of y_millions, with the chosen regressors plus the year dummies.
#
# Standard errors are heteroskedasticity-robust (HC1). The default
# model-based Poisson covariance assumes Var(y) = E(y); it is therefore not
# valid for a continuous monetary outcome and it changes with the units of y
# (dividing y by 1e6 alone rescales those standard errors). The sandwich
# estimator does not depend on the scaling.
# NOTE(review): gravity applications usually cluster by country pair; if the
# panel carries a pair identifier, cov_type="cluster" would be preferable.
all_results = []


def _sig_stars(p):
    """Return the significance marker printed next to a coefficient with p-value *p*."""
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "**"
    if p < 0.05:
        return "*"
    if p < 0.1:
        return "+"
    return ""


for model_name, xvars in model_specs.items():
    print(f"\n{'='*70}")
    print(f"Estimating {model_name}")
    print(f"{'='*70}")

    all_rhs = xvars + yr_cols
    cols_needed = ["y_millions"] + xvars
    sub = df.dropna(subset=cols_needed).copy()

    # Attach the year dummies (aligned on the shared row index).
    sub = sub.join(year_dummies)

    # If sample > 120K, take 50% random subsample for speed.
    # random_state=42 draws the same rows as seeding the global NumPy RNG
    # with 42, without mutating global random state.
    subsample_note = ""
    if len(sub) > 120_000:
        sub = sub.sample(frac=0.5, random_state=42).copy()
        subsample_note = " [50% subsample]"
        print(f"  Large sample — using 50% random subsample")

    # Force a float design matrix (mixed bool/int/float columns would
    # otherwise yield an object array). has_constant="add" always prepends
    # the intercept so that xnames stays aligned with the columns of X even
    # if some regressor happens to be constant in this sample.
    y = sub["y_millions"].to_numpy(dtype=float)
    X = sm.add_constant(sub[all_rhs].to_numpy(dtype=float), has_constant="add")
    xnames = ["const"] + all_rhs

    n_obs = len(y)
    n_zero = (y == 0).sum()
    print(f"  N = {n_obs:,} (zeros: {n_zero:,}, positive: {n_obs - n_zero:,}){subsample_note}")

    # Build the model outside the fitting try-block: if construction fails
    # there is nothing to fall back to, and the Newton fallback must never
    # pick up the model object left over from a previous specification.
    try:
        poisson_model = GLM(y, X, family=sm.families.Poisson())
    except Exception as e:
        print(f"  Could not set up the Poisson GLM ({e}); skipping this model")
        continue

    t0 = time.time()
    try:
        result = poisson_model.fit(maxiter=200, method="IRLS", cov_type="HC1")
        used_newton = False
    except Exception as e:
        print(f"  IRLS failed ({e}), trying Newton-Raphson...")
        try:
            result = poisson_model.fit(maxiter=200, method="newton", cov_type="HC1")
            used_newton = True
        except Exception as e2:
            print(f"  Newton also failed: {e2}")
            continue
    elapsed = time.time() - t0

    # ConvergenceWarning is silenced at the top of the script, so the
    # convergence flag has to be inspected explicitly. IRLS results expose
    # `converged`; optimizer-based fits report it in `mle_retvals`.
    converged = getattr(result, "converged", None)
    if converged is None:
        retvals = getattr(result, "mle_retvals", None) or {}
        converged = retvals.get("converged", True)
    converged = bool(converged)

    if not converged:
        print(f"  WARNING: estimation did NOT converge within 200 iterations "
              f"({elapsed:.1f}s); estimates below may be unreliable")
    elif used_newton:
        print(f"  Newton converged in {elapsed:.1f}s")
    else:
        print(f"  Converged in {elapsed:.1f}s, {result.nobs:.0f} obs")

    # Pseudo R-squared (McFadden): 1 - llf(model) / llf(intercept-only model).
    null_model = GLM(y, np.ones((n_obs, 1)), family=sm.families.Poisson())
    null_result = null_model.fit(maxiter=100)
    pseudo_r2 = 1 - (result.llf / null_result.llf)

    # Estimates for every term except the year dummies.
    reported = [
        (var, result.params[i], result.bse[i], result.tvalues[i], result.pvalues[i])
        for i, var in enumerate(xnames)
        if not var.startswith("yr_")
    ]

    model_label = model_name + subsample_note
    for var, coef, se, z, p in reported:
        all_results.append({
            "model": model_label,
            "variable": var,
            "coefficient": coef,
            "std_error": se,
            "z_stat": z,
            "p_value": p,
        })

    # Metadata rows (variable names start with "_"); the value is stored in
    # the coefficient column. _Converged is 1.0 when the fit converged.
    for meta_var, meta_val in [("_Pseudo_R2", pseudo_r2), ("_N_obs", n_obs),
                                ("_N_zeros", n_zero), ("_Log_Likelihood", result.llf),
                                ("_Converged", float(converged))]:
        all_results.append({
            "model": model_label,
            "variable": meta_var,
            "coefficient": meta_val,
            "std_error": np.nan,
            "z_stat": np.nan,
            "p_value": np.nan,
        })

    # Print summary
    print(f"\n  {'Variable':<30} {'Coef':>10} {'SE':>10} {'z':>8} {'p':>8}")
    print(f"  {'-'*66}")
    for var, coef, se, z, p in reported:
        print(f"  {var:<30} {coef:>10.4f} {se:>10.4f} "
              f"{z:>8.3f} {p:>8.4f} {_sig_stars(p)}")
    print(f"\n  Pseudo R² = {pseudo_r2:.4f}")
    print(f"  Log-likelihood = {result.llf:.1f}")
    print(f"  Year dummies: {len(yr_cols)} included")
    print(f"  Standard errors: heteroskedasticity-robust (HC1)")

# ---------- save ----------
# The explicit column list keeps the CSV schema (and the lookups below)
# valid even when no model could be estimated and all_results is empty.
results_df = pd.DataFrame(
    all_results,
    columns=["model", "variable", "coefficient", "std_error", "z_stat", "p_value"],
)
results_df.to_csv(OUT, index=False)
print(f"\nResults saved to {OUT}")

# ---------- comparison table ----------
print("\n" + "=" * 90)
print("COMPARISON: OLS (log) vs PPML (levels) — Demographic Variables")
print("=" * 90)

ols = pd.read_csv(OLS_RESULTS)

# Get OLS coefficients for models 2b and 2c
ols_2b = ols[ols["model"] == "2b: Gravity + Demographics"]
ols_2c = ols[ols["model"] == "2c: Gravity + Demographics + KAOPEN interactions"]


def _ppml_rows(label):
    """Return the rows of results_df whose model name contains *label* (literal match)."""
    mask = results_df["model"].astype(str).str.contains(label, regex=False)
    return results_df.loc[mask.to_numpy(dtype=bool)]


# Substring match because the stored model name may carry a
# " [50% subsample]" suffix.
ppml_2b = _ppml_rows("PPML 2b")
ppml_2c = _ppml_rows("PPML 2c")


def fmt(val_series, p_series):
    """Format one table cell: coefficient with 4 decimals plus significance marker.

    Both arguments are Series selected with the same row filter; only the
    first element is used. An empty selection is rendered as a dash.
    """
    if len(val_series) == 0:
        return "      —   "
    v = val_series.iloc[0]
    p = p_series.iloc[0]
    sig = "***" if p < 0.001 else "**" if p < 0.01 else "*" if p < 0.05 else "+" if p < 0.1 else ""
    return f"{v:>8.4f}{sig:<2}"


def _coef_cell(frame, var):
    """Return the formatted cell for variable *var* in results table *frame*."""
    rows = frame.loc[frame["variable"] == var]
    return fmt(rows["coefficient"], rows["p_value"])


def _print_coef_rows(variables):
    """Print one comparison line (OLS 2b, PPML 2b | OLS 2c, PPML 2c) per variable."""
    for var in variables:
        print(f"{var:<25} {_coef_cell(ols_2b, var):>10} {_coef_cell(ppml_2b, var):>10}"
              f"  |  {_coef_cell(ols_2c, var):>10} {_coef_cell(ppml_2c, var):>10}")


def _meta_cell(frame, key, spec):
    """Return metadata value *key* from *frame*, right-aligned to width 10.

    *spec* is the numeric part of the format (e.g. ".4f"). A dash is returned
    when the row is absent, e.g. because a model was skipped or the OLS file
    does not contain that entry.
    """
    vals = frame.loc[frame["variable"] == key, "coefficient"]
    if len(vals) == 0:
        return f"{'—':>10}"
    return f"{vals.iloc[0]:>10{spec}}"


compare_vars = ["dZ_1", "dZ_2", "dZ_3", "kaopen_j",
                "dZ_1_x_kaopen_j", "dZ_2_x_kaopen_j", "dZ_3_x_kaopen_j"]

print(f"\n{'Variable':<25} {'OLS 2b':>10} {'PPML 2b':>10}  |  {'OLS 2c':>10} {'PPML 2c':>10}")
print("-" * 80)
_print_coef_rows(compare_vars)

# Also show gravity variables
print("\n--- Gravity variables ---")
gravity_vars = ["log_dist", "contiguity", "common_lang_official", "colonial_ties", "log_gdp_product"]
_print_coef_rows(gravity_vars)

# Metadata comparison
print("\n--- Model fit ---")
print(f"{'R² / Pseudo R²':<25} {_meta_cell(ols_2b, '_R_squared', '.4f')} "
      f"{_meta_cell(ppml_2b, '_Pseudo_R2', '.4f')}  |  "
      f"{_meta_cell(ols_2c, '_R_squared', '.4f')} {_meta_cell(ppml_2c, '_Pseudo_R2', '.4f')}")
print(f"{'N observations':<25} {_meta_cell(ols_2b, '_N_obs', '.0f')} "
      f"{_meta_cell(ppml_2b, '_N_obs', '.0f')}  |  "
      f"{_meta_cell(ols_2c, '_N_obs', '.0f')} {_meta_cell(ppml_2c, '_N_obs', '.0f')}")
print(f"{'N zeros (PPML only)':<25} {'—':>10} {_meta_cell(ppml_2b, '_N_zeros', '.0f')}"
      f"  |  {'—':>10} {_meta_cell(ppml_2c, '_N_zeros', '.0f')}")

print("\nNote: OLS uses log(portfolio_total) on positive obs only;")
print("      PPML uses portfolio_total / 1M in levels, including zeros.")
print("      PPML coefficients are semi-elasticities (% change in E[y]).")
print("Done.")
